#ifndef NEW6X3_H
#define NEW6X3_H

// Compile-time channel-count hints, visible to every file including this header.
// NOTE(review): presumably the expected output-channel (OC) and input-channel
// (IC) counts used for tuning - confirm against the winconv implementation,
// which is not visible from this header.
constexpr int hintOC = 64;
constexpr int hintIC = 32;

#include <algorithm>
#include <cassert>
#include <immintrin.h>
#include <xmmintrin.h>

// NOTE(review): a using-directive at header scope exposes all of namespace std
// to every file that includes this header. Left in place because code that
// includes this header may rely on unqualified std names.
using namespace std;

/// Lightweight (start, end) index pair standing in for an iterator range.
/// size() reports end - start, so a default-constructed Range (-1, -1) has size 0.
struct Range {
    int start;
    int end;

    /// Default range: both bounds are set to -1.
    Range(): Range(-1, -1) {}

    /// Range with explicit bounds.
    Range(int start, int end): start{start}, end{end} {}

    /// Distance between the two bounds (end - start).
    inline int size() const {
        return end - start;
    }
};

/// Integer division rounded towards positive infinity: ceil(a / b).
///
/// Implemented with the truncated quotient and remainder instead of the
/// classic `(a + b - 1) / b`, because that sum overflows `int` (undefined
/// behaviour) when `a` is close to INT_MAX and gives a wrong result for some
/// negative operands.
/// @param a Dividend
/// @param b Divisor, must not be 0 (INT_MIN / -1 is still undefined)
/// @return Smallest integer that is >= a / b
/// @see https://stackoverflow.com/questions/2745074/fast-ceiling-of-an-integer-division-in-c-c
inline int ceildiv(int a, int b) {
    const int quotient = a / b;   // truncated towards zero
    const int remainder = a % b;  // carries the sign of `a`
    // Truncation rounded *down* only when the exact quotient is positive and
    // inexact, i.e. the remainder is non-zero and has the same sign as `b`.
    const bool rounded_down = remainder != 0 && ((remainder < 0) == (b < 0));
    return quotient + (rounded_down ? 1 : 0);
}

/// Build an 8-bit lane mask with bits i..5 set and bits 6..7 clear.
///
/// Resulting values: 0 -> 63, 1 -> 62, 2 -> 60, 3 -> 56, 4 -> 48, 5 -> 32.
/// NOTE(review): presumably used as a __mmask8 selecting the lanes of a
/// 6-wide tile that should be written out - confirm at the call sites.
/// @param i Index of the first enabled lane, expected to be in [0, 5]
/// @return The mask; for an out-of-range `i` the assertion fires in debug
///         builds and an empty mask (0) is returned in NDEBUG builds, so the
///         function never falls off its end without returning a value
[[gnu::always_inline]] inline
unsigned char gen_mmask8_for_dump(int i) {
    assert(0 <= i && i <= 5);
    if (i < 0 || i > 5) {
        return 0;  // defined fallback when assertions are compiled out
    }
    // Shift the six valid lanes up by i and drop whatever leaves bits 0..5.
    return static_cast<unsigned char>((0x3F << i) & 0x3F);
}

/// Transpose an 8x8 float matrix held as eight SIMD vectors (one row per vector).
///
/// This primary template is only declared; the definitions are the explicit
/// specializations for __m256 and __m512 provided in this header.
/// @tparam Intrinsic SIMD vector type (__m256 or __m512)
/// @param row Vectors to be transposed, will be junk data after transpose
/// @param tr Transposed vectors
/// @see https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2
template<typename Intrinsic> [[maybe_unused, gnu::always_inline]] inline
void mm_transpose_8x8(Intrinsic row[8], Intrinsic tr[8]);

/// AVX specialization: transposes an 8x8 float matrix stored as eight __m256 rows.
/// @param row Input rows; overwritten with intermediate values by the second stage
/// @param tr Receives the transposed rows (also used as scratch by the first stage)
template<> [[gnu::always_inline]] inline
void mm_transpose_8x8<__m256>(__m256 row[8], __m256 tr[8]) {
    // Shuffle selectors: within every 128-bit lane take elements {0,1} (resp.
    // {2,3}) of the first operand followed by the same pair of the second one.
    constexpr int take_01 = _MM_SHUFFLE(1, 0, 1, 0);
    constexpr int take_23 = _MM_SHUFFLE(3, 2, 3, 2);
    // 128-bit lane selectors: low lanes of both operands / high lanes of both.
    constexpr int low_lanes = 0x20;
    constexpr int high_lanes = 0x31;

    // Stage 1: interleave single floats of each pair of neighbouring rows.
    tr[0] = _mm256_unpacklo_ps(row[0], row[1]);
    tr[1] = _mm256_unpackhi_ps(row[0], row[1]);
    tr[2] = _mm256_unpacklo_ps(row[2], row[3]);
    tr[3] = _mm256_unpackhi_ps(row[2], row[3]);
    tr[4] = _mm256_unpacklo_ps(row[4], row[5]);
    tr[5] = _mm256_unpackhi_ps(row[4], row[5]);
    tr[6] = _mm256_unpacklo_ps(row[6], row[7]);
    tr[7] = _mm256_unpackhi_ps(row[6], row[7]);

    // Stage 2: combine float pairs, giving 4-element column pieces per lane.
    row[0] = _mm256_shuffle_ps(tr[0], tr[2], take_01);
    row[1] = _mm256_shuffle_ps(tr[0], tr[2], take_23);
    row[2] = _mm256_shuffle_ps(tr[1], tr[3], take_01);
    row[3] = _mm256_shuffle_ps(tr[1], tr[3], take_23);
    row[4] = _mm256_shuffle_ps(tr[4], tr[6], take_01);
    row[5] = _mm256_shuffle_ps(tr[4], tr[6], take_23);
    row[6] = _mm256_shuffle_ps(tr[5], tr[7], take_01);
    row[7] = _mm256_shuffle_ps(tr[5], tr[7], take_23);

    // Stage 3: stitch the 128-bit pieces from rows 0-3 and rows 4-7 together.
    tr[0] = _mm256_permute2f128_ps(row[0], row[4], low_lanes);
    tr[1] = _mm256_permute2f128_ps(row[1], row[5], low_lanes);
    tr[2] = _mm256_permute2f128_ps(row[2], row[6], low_lanes);
    tr[3] = _mm256_permute2f128_ps(row[3], row[7], low_lanes);
    tr[4] = _mm256_permute2f128_ps(row[0], row[4], high_lanes);
    tr[5] = _mm256_permute2f128_ps(row[1], row[5], high_lanes);
    tr[6] = _mm256_permute2f128_ps(row[2], row[6], high_lanes);
    tr[7] = _mm256_permute2f128_ps(row[3], row[7], high_lanes);
}

/// AVX-512 counterpart of the __m256 8x8 transpose, applied to an 8x16 float matrix.
/// @note The two 8x8 blocks held in the low and the high 256-bit halves of
///       the eight vectors are transposed independently, each staying in
///       its own half: \n
///       +---+---+ \n
///       | T | T | \n
///       +---+---+ \n
/// @param row Vectors to be transposed; overwritten with intermediate values
/// @param tr Transposed vectors
/// @see https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2
template<> [[gnu::always_inline]] inline
void mm_transpose_8x8<__m512>(__m512 row[8], __m512 tr[8]) {
    // Shuffle selectors: within every 128-bit lane take elements {0,1} (resp.
    // {2,3}) of the first operand followed by the same pair of the second one.
    constexpr int take_01 = _MM_SHUFFLE(1, 0, 1, 0);
    constexpr int take_23 = _MM_SHUFFLE(3, 2, 3, 2);
    // Two-source permute indices (octal; 0-15 address the first source, 16-31
    // the second). For each 256-bit half of the result:
    //   lower_lanes -> lower 128-bit lane of that half from the first source,
    //                  then the same lane from the second source
    //   upper_lanes -> likewise with the upper 128-bit lane of that half
    const __m512i lower_lanes = _mm512_set_epi32(
        033, 032, 031, 030, 013, 012, 011, 010,
        023, 022, 021, 020, 003, 002, 001, 000);
    const __m512i upper_lanes = _mm512_set_epi32(
        037, 036, 035, 034, 017, 016, 015, 014,
        027, 026, 025, 024, 007, 006, 005, 004);

    // Stage 1: interleave single floats of each pair of neighbouring rows.
    tr[0] = _mm512_unpacklo_ps(row[0], row[1]);
    tr[1] = _mm512_unpackhi_ps(row[0], row[1]);
    tr[2] = _mm512_unpacklo_ps(row[2], row[3]);
    tr[3] = _mm512_unpackhi_ps(row[2], row[3]);
    tr[4] = _mm512_unpacklo_ps(row[4], row[5]);
    tr[5] = _mm512_unpackhi_ps(row[4], row[5]);
    tr[6] = _mm512_unpacklo_ps(row[6], row[7]);
    tr[7] = _mm512_unpackhi_ps(row[6], row[7]);

    // Stage 2: combine float pairs, giving 4-element column pieces per lane.
    row[0] = _mm512_shuffle_ps(tr[0], tr[2], take_01);
    row[1] = _mm512_shuffle_ps(tr[0], tr[2], take_23);
    row[2] = _mm512_shuffle_ps(tr[1], tr[3], take_01);
    row[3] = _mm512_shuffle_ps(tr[1], tr[3], take_23);
    row[4] = _mm512_shuffle_ps(tr[4], tr[6], take_01);
    row[5] = _mm512_shuffle_ps(tr[4], tr[6], take_23);
    row[6] = _mm512_shuffle_ps(tr[5], tr[7], take_01);
    row[7] = _mm512_shuffle_ps(tr[5], tr[7], take_23);

    // Stage 3: stitch the 128-bit pieces from rows 0-3 and rows 4-7 together,
    // separately inside each 256-bit half.
    tr[0] = _mm512_permutex2var_ps(row[0], lower_lanes, row[4]);
    tr[1] = _mm512_permutex2var_ps(row[1], lower_lanes, row[5]);
    tr[2] = _mm512_permutex2var_ps(row[2], lower_lanes, row[6]);
    tr[3] = _mm512_permutex2var_ps(row[3], lower_lanes, row[7]);
    tr[4] = _mm512_permutex2var_ps(row[0], upper_lanes, row[4]);
    tr[5] = _mm512_permutex2var_ps(row[1], upper_lanes, row[5]);
    tr[6] = _mm512_permutex2var_ps(row[2], upper_lanes, row[6]);
    tr[7] = _mm512_permutex2var_ps(row[3], upper_lanes, row[7]);
}

/// Transpose only the first three columns of an 8x8 block held in eight vectors.
///
/// This primary template is only declared; the definitions are the explicit
/// specializations for __m256 and __m512 provided in this header.
/// @tparam Intrinsic SIMD vector type (__m256 or __m512)
/// @param row Vectors to be transposed, will be junk data after transpose
/// @param tr Transposed vectors, note that only tr[0..2] are useful
/// @see mm_transpose_8x8
template<typename Intrinsic> [[maybe_unused, gnu::always_inline]] inline
void mm_transpose_8x3(Intrinsic row[8], Intrinsic tr[8]);

/// AVX specialization of the reduced transpose: same network as the full 8x8
/// version with every step that only feeds output columns 3..7 left out.
/// @param row Input rows; row[0..2] and row[4..6] are overwritten with
///            intermediate values, row[3] and row[7] are left untouched
/// @param tr On return tr[0..2] hold columns 0..2 of the input block;
///           tr[3..7] still contain first-stage intermediates
template<> [[gnu::always_inline]] inline
void mm_transpose_8x3<__m256>(__m256 row[8], __m256 tr[8]) {
    // Shuffle selectors: within every 128-bit lane take elements {0,1} (resp.
    // {2,3}) of the first operand followed by the same pair of the second one.
    constexpr int take_01 = _MM_SHUFFLE(1, 0, 1, 0);
    constexpr int take_23 = _MM_SHUFFLE(3, 2, 3, 2);
    // 128-bit lane selector: low lane of the first operand, then of the second.
    constexpr int low_lanes = 0x20;

    // Stage 1: interleave single floats of each pair of neighbouring rows.
    tr[0] = _mm256_unpacklo_ps(row[0], row[1]);
    tr[1] = _mm256_unpackhi_ps(row[0], row[1]);
    tr[2] = _mm256_unpacklo_ps(row[2], row[3]);
    tr[3] = _mm256_unpackhi_ps(row[2], row[3]);
    tr[4] = _mm256_unpacklo_ps(row[4], row[5]);
    tr[5] = _mm256_unpackhi_ps(row[4], row[5]);
    tr[6] = _mm256_unpacklo_ps(row[6], row[7]);
    tr[7] = _mm256_unpackhi_ps(row[6], row[7]);

    // Stage 2: only the three combinations needed for columns 0..2.
    row[0] = _mm256_shuffle_ps(tr[0], tr[2], take_01);
    row[1] = _mm256_shuffle_ps(tr[0], tr[2], take_23);
    row[2] = _mm256_shuffle_ps(tr[1], tr[3], take_01);
    row[4] = _mm256_shuffle_ps(tr[4], tr[6], take_01);
    row[5] = _mm256_shuffle_ps(tr[4], tr[6], take_23);
    row[6] = _mm256_shuffle_ps(tr[5], tr[7], take_01);

    // Stage 3: join the pieces coming from rows 0-3 and rows 4-7.
    tr[0] = _mm256_permute2f128_ps(row[0], row[4], low_lanes);
    tr[1] = _mm256_permute2f128_ps(row[1], row[5], low_lanes);
    tr[2] = _mm256_permute2f128_ps(row[2], row[6], low_lanes);
}

/// AVX-512 specialization of the reduced transpose. The low and the high
/// 256-bit halves of the vectors are processed independently, so tr[k]
/// (k = 0..2) ends up with column k of the 8x8 block stored in the low halves
/// followed by column k of the block stored in the high halves.
/// @param row Input rows; row[0..2] and row[4..6] are overwritten with
///            intermediate values, row[3] and row[7] are left untouched
/// @param tr On return tr[0..2] hold the transposed columns; tr[3..7] still
///           contain first-stage intermediates
/// @see mm_transpose_8x8
template<> [[gnu::always_inline]] inline
void mm_transpose_8x3<__m512>(__m512 row[8], __m512 tr[8]) {
    // Shuffle selectors: within every 128-bit lane take elements {0,1} (resp.
    // {2,3}) of the first operand followed by the same pair of the second one.
    constexpr int take_01 = _MM_SHUFFLE(1, 0, 1, 0);
    constexpr int take_23 = _MM_SHUFFLE(3, 2, 3, 2);
    // Two-source permute indices (octal; 0-15 address the first source, 16-31
    // the second): for each 256-bit half of the result, the lower 128-bit lane
    // of that half from the first source, then the same lane from the second.
    const __m512i lower_lanes = _mm512_set_epi32(
        033, 032, 031, 030, 013, 012, 011, 010,
        023, 022, 021, 020, 003, 002, 001, 000);

    // Stage 1: interleave single floats of each pair of neighbouring rows.
    tr[0] = _mm512_unpacklo_ps(row[0], row[1]);
    tr[1] = _mm512_unpackhi_ps(row[0], row[1]);
    tr[2] = _mm512_unpacklo_ps(row[2], row[3]);
    tr[3] = _mm512_unpackhi_ps(row[2], row[3]);
    tr[4] = _mm512_unpacklo_ps(row[4], row[5]);
    tr[5] = _mm512_unpackhi_ps(row[4], row[5]);
    tr[6] = _mm512_unpacklo_ps(row[6], row[7]);
    tr[7] = _mm512_unpackhi_ps(row[6], row[7]);

    // Stage 2: only the three combinations needed for columns 0..2.
    row[0] = _mm512_shuffle_ps(tr[0], tr[2], take_01);
    row[1] = _mm512_shuffle_ps(tr[0], tr[2], take_23);
    row[2] = _mm512_shuffle_ps(tr[1], tr[3], take_01);
    row[4] = _mm512_shuffle_ps(tr[4], tr[6], take_01);
    row[5] = _mm512_shuffle_ps(tr[4], tr[6], take_23);
    row[6] = _mm512_shuffle_ps(tr[5], tr[7], take_01);

    // Stage 3: join the pieces coming from rows 0-3 and rows 4-7.
    tr[0] = _mm512_permutex2var_ps(row[0], lower_lanes, row[4]);
    tr[1] = _mm512_permutex2var_ps(row[1], lower_lanes, row[5]);
    tr[2] = _mm512_permutex2var_ps(row[2], lower_lanes, row[6]);
}

/// Regroup an 8x16 float matrix from two side-by-side 8x8 blocks into two
/// stacked 4x16 blocks (8x8 column to 4x16 row).
/// @note Layout change performed: \n
///       +-------+-------+     +-------+-------+ \n
///       +       +       +     +       1       + \n
///       +   1   +   2   + --> +-------+-------+ \n
///       +       +       +     +       2       + \n
///       +-------+-------+     +-------+-------+ \n
///       tr[0..3] receive the low 256-bit halves of the row pairs (0,1),
///       (2,3), (4,5), (6,7); tr[4..7] receive the matching high halves.
/// @param row Eight input vectors (only read)
/// @param tr Eight output vectors
[[gnu::always_inline]] inline
void mm_transpose_8x16_col2row(__m512 row[8], __m512 tr[8]) {
    // Two-source permute indices (octal; 0-15 address the first source, 16-31
    // the second).
    // low_halves : elements 0..7 of the first source, then 0..7 of the second.
    const __m512i low_halves = _mm512_set_epi32(
        027, 026, 025, 024, 023, 022, 021, 020,
        007, 006, 005, 004, 003, 002, 001, 000);
    // high_halves: elements 8..15 of the first source, then 8..15 of the second.
    const __m512i high_halves = _mm512_set_epi32(
        037, 036, 035, 034, 033, 032, 031, 030,
        017, 016, 015, 014, 013, 012, 011, 010);

    // Block 1: low halves of consecutive row pairs.
    tr[0] = _mm512_permutex2var_ps(row[0], low_halves, row[1]);
    tr[1] = _mm512_permutex2var_ps(row[2], low_halves, row[3]);
    tr[2] = _mm512_permutex2var_ps(row[4], low_halves, row[5]);
    tr[3] = _mm512_permutex2var_ps(row[6], low_halves, row[7]);

    // Block 2: high halves of the same row pairs.
    tr[4] = _mm512_permutex2var_ps(row[0], high_halves, row[1]);
    tr[5] = _mm512_permutex2var_ps(row[2], high_halves, row[3]);
    tr[6] = _mm512_permutex2var_ps(row[4], high_halves, row[5]);
    tr[7] = _mm512_permutex2var_ps(row[6], high_halves, row[7]);
}

/// Regroup an 8x16 float matrix from two stacked 4x16 blocks into two
/// side-by-side 8x8 blocks (4x16 row to 8x8 column).
/// @note Layout change performed: \n
///       +-------+-------+     +-------+-------+ \n
///       +       1       +     +       +       + \n
///       +-------+-------+ --> +   1   +   2   + \n
///       +       2       +     +       +       + \n
///       +-------+-------+     +-------+-------+ \n
///       Input vector k (k = 0..3) is paired with input vector k + 4; their
///       low 256-bit halves form tr[2k], their high halves form tr[2k + 1].
/// @param row Eight input vectors (only read)
/// @param tr Eight output vectors
[[gnu::always_inline]] inline
void mm_transpose_8x16_row2col(__m512 row[8], __m512 tr[8]) {
    // Two-source permute indices (octal; 0-15 address the first source, 16-31
    // the second).
    // low_halves : elements 0..7 of the first source, then 0..7 of the second.
    const __m512i low_halves = _mm512_set_epi32(
        027, 026, 025, 024, 023, 022, 021, 020,
        007, 006, 005, 004, 003, 002, 001, 000);
    // high_halves: elements 8..15 of the first source, then 8..15 of the second.
    const __m512i high_halves = _mm512_set_epi32(
        037, 036, 035, 034, 033, 032, 031, 030,
        017, 016, 015, 014, 013, 012, 011, 010);

    // Pair (row 0, row 4).
    tr[0] = _mm512_permutex2var_ps(row[0], low_halves, row[4]);
    tr[1] = _mm512_permutex2var_ps(row[0], high_halves, row[4]);
    // Pair (row 1, row 5).
    tr[2] = _mm512_permutex2var_ps(row[1], low_halves, row[5]);
    tr[3] = _mm512_permutex2var_ps(row[1], high_halves, row[5]);
    // Pair (row 2, row 6).
    tr[4] = _mm512_permutex2var_ps(row[2], low_halves, row[6]);
    tr[5] = _mm512_permutex2var_ps(row[2], high_halves, row[6]);
    // Pair (row 3, row 7).
    tr[6] = _mm512_permutex2var_ps(row[3], low_halves, row[7]);
    tr[7] = _mm512_permutex2var_ps(row[3], high_halves, row[7]);
}

/// Winograd F(6x3) data transform: applies B^T to eight inputs (B.T @ D).
///
/// Outputs 1..6 are produced in (sum, difference) pairs that share the same
/// two partial terms; outputs 0 and 7 are computed on their own.
/// @tparam Intrinsic Element type supporting +, - and multiplication by a float
/// @param D Raw data, for this task should be raw image pixels (8 elements)
/// @param BtD Transformed data (8 elements)
template<typename Intrinsic> [[gnu::always_inline]] inline
void transform_BtD_6x3(const Intrinsic D[8], Intrinsic BtD[8]) {
    // First and last rows of B^T.
    BtD[0] = D[0] + 5.25f * (D[4] - D[2]) - D[6];
    BtD[7] = D[7] - D[1] + 5.25f * (D[3] - D[5]);

    // Rows 1 and 2.
    {
        const Intrinsic odd_part = D[1] - 4.25f * D[3] + D[5];
        const Intrinsic even_part = D[2] - 4.25f * D[4] + D[6];
        BtD[1] = odd_part + even_part;
        BtD[2] = even_part - odd_part;
    }

    // Rows 3 and 4.
    {
        const Intrinsic odd_part = 0.5f * D[1] - 2.5f * D[3] + 2.f * D[5];
        const Intrinsic even_part = D[6] + 0.25f * D[2] - 1.25f * D[4];
        BtD[3] = odd_part + even_part;
        BtD[4] = even_part - odd_part;
    }

    // Rows 5 and 6.
    {
        const Intrinsic odd_part = 2.f * D[1] - 2.5f * D[3] + 0.5f * D[5];
        const Intrinsic even_part = D[6] + 4.f * D[2] - 5.f * D[4];
        BtD[5] = odd_part + even_part;
        BtD[6] = even_part - odd_part;
    }
}

/// Winograd F(6x3) filter transform: applies G to three filter taps (G @ D).
///
/// Outputs 1..6 are produced in (sum, difference) pairs: the part depending
/// on the outer taps D[0], D[2] plus or minus the part depending on the
/// middle tap D[1]. Outputs 0 and 7 are plain copies of D[0] and D[2].
/// @tparam Intrinsic Element type supporting +, - and multiplication by a float
/// @param D Raw filter (3 elements)
/// @param GD Transformed data (8 elements)
template<typename Intrinsic> [[gnu::always_inline]] inline
void transform_GD_6x3(const Intrinsic D[3], Intrinsic GD[8]) {
    // First and last rows of G.
    GD[0] = D[0];
    GD[7] = D[2];

    // Rows 1 and 2.
    {
        const Intrinsic outer = -2.f/9.f * (D[0] + D[2]);
        const Intrinsic middle = -2.f/9.f * D[1];
        GD[1] = outer + middle;
        GD[2] = outer - middle;
    }

    // Rows 3 and 4.
    {
        const Intrinsic outer = 1.f/90.f * D[0] + 2.f/45.f * D[2];
        const Intrinsic middle = 1.f/45.f * D[1];
        GD[3] = outer + middle;
        GD[4] = outer - middle;
    }

    // Rows 5 and 6.
    {
        const Intrinsic outer = 32.f/45.f * D[0] + 8.f/45.f * D[2];
        const Intrinsic middle = 16.f/45.f * D[1];
        GD[5] = outer + middle;
        GD[6] = outer - middle;
    }
}

/// Winograd F(6x3) output transform: applies A^T to eight inputs (A.T @ D).
///
/// Even outputs are built from the pairwise sums D[1]+D[2], D[3]+D[4],
/// D[5]+D[6]; odd outputs from the matching pairwise differences. D[0] only
/// contributes to output 0 and D[7] only to output 5.
/// @tparam Intrinsic Element type supporting +, - and multiplication by a float
/// @param D Pixels whose channels have been summed, filter dimension remains (8 elements)
/// @param AtD Transformed data (6 elements)
template<typename Intrinsic> [[gnu::always_inline]] inline
void transform_AtD_6x3(const Intrinsic D[8], Intrinsic AtD[6]) {
    // Even outputs: weighted pairwise sums.
    {
        const Intrinsic sum12 = D[1] + D[2];
        const Intrinsic sum34 = D[3] + D[4];
        const Intrinsic sum56 = D[5] + D[6];
        AtD[0] = sum12 + sum34 + sum56 + D[0];
        AtD[2] = sum12 + 4.f * sum34 + 0.25f * sum56;
        AtD[4] = sum12 + 16.f * sum34 + 0.0625f * sum56;
    }

    // Odd outputs: weighted pairwise differences.
    {
        const Intrinsic diff12 = D[1] - D[2];
        const Intrinsic diff34 = D[3] - D[4];
        const Intrinsic diff56 = D[5] - D[6];
        AtD[1] = diff12 + 2.f * diff34 + 0.5f * diff56;
        AtD[3] = diff12 + 8.f * diff34 + 0.125f * diff56;
        AtD[5] = diff12 + 32.f * diff34 + 0.03125f * diff56 + D[7];
    }
}

/// Convolution entry point, declared with C linkage; the definition is not
/// part of this header. All three pointers are declared __restrict__, i.e.
/// the caller promises that the buffers do not alias each other.
/// NOTE(review): the parameter meanings below are inferred from their names
/// only - confirm against the implementation.
/// @param image  Input image data (read-only)
/// @param IH     Presumably the input height
/// @param IW     Presumably the input width
/// @param IC     Presumably the number of input channels
/// @param filter Filter weights (read-only)
/// @param OC     Presumably the number of output channels
/// @param N      Presumably the batch size
/// @param result Output buffer written by the function
extern "C"
void winconv(const float *__restrict__ image, int IH,
             int IW, int IC, const float *__restrict__ filter,
             int OC, int N, float *__restrict__ result);

#endif //NEW6X3_H
